/*
*  This file is part of ygg-brute
*  Copyright (c) 2020 ygg-brute authors
*  See LICENSE for licensing information
*/

#pragma once

#include <exception>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <CL/cl2.hpp>

// Evaluates the OpenCL call/expression F once and reports any status other than
// CL_SUCCESS. The report reads "OpenCL error: <F text>: <status>".
//  - Normally it throws std::runtime_error with that message.
//  - If an exception is already propagating (std::uncaught_exceptions() != 0,
//    e.g. F ran inside a destructor during unwinding), it prints the message to
//    std::cerr instead, because throwing there would call std::terminate.
// The macro's locals use distinctive names so that F may itself mention common
// identifiers such as `ret` or `str` without colliding with them.
#define OPENCL_CHECK(F)                                                                \
    do                                                                                 \
    {                                                                                  \
        auto opencl_check_status_ = (F);                                               \
        if (opencl_check_status_ != CL_SUCCESS)                                        \
        {                                                                              \
            auto opencl_check_msg_ = "OpenCL error: " + std::string(#F ": ")           \
                                     + std::to_string(opencl_check_status_);           \
            if (!std::uncaught_exceptions())                                           \
                throw std::runtime_error{opencl_check_msg_};                           \
            else                                                                       \
                std::cerr << "Error: " << opencl_check_msg_ << '\n';                   \
        }                                                                              \
    } while (0)

namespace opencl {

inline std::vector<cl::Device> get_gpus()
{
    std::vector<cl::Platform> platforms;
    cl::Platform::get(&platforms);
    if(platforms.empty())
        throw std::runtime_error{"No OpenCL platforms found"};

    std::vector<cl::Device> gpus;
    for(auto& platform : platforms) {
        platform.getDevices(CL_DEVICE_TYPE_GPU, &gpus);
    }

    return gpus;
}

/// Returns the GPU at position `i` of the list produced by get_gpus().
///
/// @throws std::runtime_error if `i` is past the end of that list. get_gpus()
///         itself may also throw.
inline cl::Device get_gpu(size_t i)
{
    const std::vector<cl::Device> devices = get_gpus();
    if(i < devices.size()) {
        return devices[i];
    }
    throw std::runtime_error{"Invalid OpenCL device index " + std::to_string(i)};
}

/// Creates an OpenCL program from source text and builds it for one device.
///
/// @param ctx      Context the program object is created in.
/// @param device   Device the program is built for.
/// @param data     Program source text. It is read as (pointer, length), so it
///                 need not be NUL-terminated.
/// @param size     Length of the source text in bytes.
/// @param options  Build options; an empty string passes no options.
/// @return The successfully built program.
/// @throws std::runtime_error if the program object cannot be created, or if
///         the build does not return CL_SUCCESS. In the build case the message
///         carries the device's build log.
inline cl::Program make_program(
    const cl::Context& ctx,
    const cl::Device& device,
    const void* data,
    size_t size,
    std::string options = ""
) {
    cl::Program::Sources sources;
    sources.emplace_back(static_cast<const char*>(data), size);

    // When bindings exceptions are disabled, the constructor reports failure only
    // through `err`. Ignoring it would leave a null program whose failure only
    // shows up later as a confusing "Build failed" with no useful log.
    cl_int err = CL_SUCCESS;
    cl::Program program(ctx, sources, &err);
    if(err != CL_SUCCESS) {
        throw std::runtime_error{"OpenCL error: program creation failed: " + std::to_string(err)};
    }

    if(program.build({device}, options.empty() ? nullptr : options.c_str()) != CL_SUCCESS) {
        throw std::runtime_error{"Build failed: " + program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(device)};
    }
    return program;
}

/// Enqueues a non-blocking copy of `size` bytes from host memory `src` to the
/// start (offset 0) of device buffer `dst` on queue `q`.
///
/// The write is non-blocking (CL_FALSE), so `src` must stay valid and unmodified
/// until the enqueued command has completed.
/// @return The status returned by enqueueWriteBuffer.
inline auto memcpy_to_device(const cl::Buffer& dst, const void* src, size_t size, const cl::CommandQueue& q)
{
    constexpr cl_bool blocking = CL_FALSE;
    constexpr size_t dst_offset = 0;
    return q.enqueueWriteBuffer(dst, blocking, dst_offset, size, src);
}

/// Enqueues a non-blocking copy of `size` bytes from the start (offset 0) of
/// device buffer `src` into host memory `dst` on queue `q`.
///
/// The read is non-blocking (CL_FALSE), so `dst` holds the data only after the
/// enqueued command has completed.
/// @return The status returned by enqueueReadBuffer.
inline auto memcpy_to_host(void* dst, const cl::Buffer& src, size_t size, const cl::CommandQueue& q)
{
    constexpr cl_bool blocking = CL_FALSE;
    constexpr size_t src_offset = 0;
    return q.enqueueReadBuffer(src, blocking, src_offset, size, dst);
}

/// Typed form of memcpy_to_device: `size` counts elements of T, not bytes.
/// It has the same non-blocking semantics and returns the same status value.
template<typename T>
inline auto copy_to_device(const cl::Buffer& dst, const T* src, size_t size, const cl::CommandQueue& q)
{
    const size_t n_bytes = sizeof(T) * size;
    return memcpy_to_device(dst, src, n_bytes, q);
}

/// Typed form of memcpy_to_host: `size` counts elements of T, not bytes.
/// It has the same non-blocking semantics and returns the same status value.
template<typename T>
inline auto copy_to_host(T* dst, const cl::Buffer& src, size_t size, const cl::CommandQueue& q)
{
    const size_t n_bytes = sizeof(T) * size;
    return memcpy_to_host(dst, src, n_bytes, q);
}

namespace detail {

inline void bind_kernel_step(cl::Kernel& kernel, size_t)
{}

template<typename Arg, typename ...Args>
inline void bind_kernel_step(cl::Kernel& kernel, size_t n, Arg&& arg, Args&& ...args)
{
    kernel.setArg(n, std::forward<Arg>(arg));
    bind_kernel_step(kernel, n + 1, std::forward<Args>(args)...);
}

} // namespace detail

/// Binds `args` to the arguments of `kernel` in order: the first value to
/// index 0, the second to index 1, and so on. Each value is forwarded to
/// cl::Kernel::setArg through detail::bind_kernel_step.
template<typename... Ts>
inline void bind_kernel(cl::Kernel& kernel, Ts&& ...args)
{
    constexpr size_t first_arg_index = 0;
    detail::bind_kernel_step(kernel, first_arg_index, std::forward<Ts>(args)...);
}

/// Enqueues `kernel` on `q` as a 1-D launch of `n_blocks` work-groups, each
/// holding `block_size` work-items. The global size is n_blocks * block_size
/// and no global offset is applied.
/// @return The status returned by enqueueNDRangeKernel.
inline auto run_kernel(const cl::Kernel& kernel, size_t n_blocks, size_t block_size, const cl::CommandQueue& q)
{
    const size_t global_items = n_blocks * block_size;
    const cl::NDRange global_range(global_items);
    const cl::NDRange local_range(block_size);
    return q.enqueueNDRangeKernel(kernel, cl::NullRange, global_range, local_range);
}

/// A device buffer paired with a host-side array of the same element count.
///
/// The device side is a CL_MEM_READ_WRITE buffer sized for `size` elements of T.
/// The host side is a value-initialized std::make_unique<T[]> array.
/// Transfers between the two are enqueued non-blocking on the caller's queue.
/// The destination therefore holds the data only once the queue has completed
/// the transfer.
template<typename T>
class DualBuffer {
public:
    /// Allocates `size` elements on the device (in `ctx`) and on the host.
    DualBuffer(const cl::Context& ctx, size_t size)
    : device_{ctx, CL_MEM_READ_WRITE, size * sizeof(T)}
    , host_(std::make_unique<T[]>(size))
    , size_{size}
    {}

    /// Number of elements (not bytes) on each side.
    size_t size() const { return size_; }

    const cl::Buffer& device() const { return device_; }
    const T* host() const { return host_.get(); }
    T* host() { return host_.get(); }

    /// Enqueues a non-blocking device-to-host copy of all elements.
    /// Fails via OPENCL_CHECK if the enqueue does not return CL_SUCCESS. That
    /// normally means throwing std::runtime_error; if an exception is already
    /// propagating, the failure is logged to std::cerr instead.
    void copy_to_host(const cl::CommandQueue& q) {
        OPENCL_CHECK(::opencl::copy_to_host(host(), device(), size(), q));
    }

    /// Enqueues a non-blocking host-to-device copy of all elements.
    /// Failure is reported the same way as in copy_to_host().
    void copy_to_device(const cl::CommandQueue& q) {
        OPENCL_CHECK(::opencl::copy_to_device(device(), host(), size(), q));
    }

private:
    cl::Buffer device_;
    std::unique_ptr<T[]> host_;
    size_t size_;
};

/// Creates a CL_MEM_READ_WRITE device buffer in `ctx` sized for `n` elements
/// of T. No host pointer is supplied, so the contents start uninitialized.
template<typename T>
inline cl::Buffer make_device_buffer(const cl::Context& ctx, size_t n)
{
    const size_t n_bytes = sizeof(T) * n;
    cl::Buffer buffer(ctx, CL_MEM_READ_WRITE, n_bytes);
    return buffer;
}

} // namespace opencl